#! /usr/bin/env python
#coding=utf-8

import os
import sqlite3

class ElfModule(dict):
	"""A module (one row of the ``modules`` table) plus its dependency links.

	Every column of the row becomes a dict item. ``__init__`` also adds derived
	items: ``type`` ("bin" or "lib"), ``unused`` (provided - used),
	``human_size`` and empty ``deps`` / ``dependedBy`` lists that the owning
	ElfModuleMgr fills with dependency objects after all modules are loaded.
	"""

	# Graphviz fill colour per ``modGroup``; any other group is drawn white.
	_GRAPHVIZ_COLORS = {
		"publicapi": "purple",
		"innerapi_chc": "green",
		"innerapi_chc_indirect": "gray",
		"innerapi_cc": "cornsilk",
	}

	def __eq__(self, other):
		# Modules are identified by their database id only.
		if not isinstance(other, ElfModule):
			return NotImplemented

		return self["id"] == other["id"]

	def __init__(self, row, mgr):
		"""Build a module from a ``modules`` row.

		row: a row fetched from the cursor returned by ``mgr.getCursor()``.
		     That cursor's ``description`` supplies the column names, so it
		     must not have executed another query since ``row`` was fetched.
		mgr: the owning manager; kept for the lazy SQL helpers below.
		"""
		cursor = mgr.getCursor()
		names = [description[0] for description in cursor.description]
		for idx, name in enumerate(names):
			self[name] = row[idx]

		# The table spells these columns "...DependedBy"; expose them with the
		# same casing as the "dependedBy" list.
		self["chipsetsdk_dependedBy"] = self.pop("chipsetsdk_DependedBy")
		self["platformsdk_dependedBy"] = self.pop("platformsdk_DependedBy")

		# These columns may be absent from the table; default them to "".
		for optional in ("labelPath", "shlib_type", "innerapi_tags"):
			self.setdefault(optional, "")

		# Keep only the part of labelPath before the first "(" (unless the
		# "(" is the very first character).
		pos = self["labelPath"].find("(")
		if pos > 0:
			self["labelPath"] = self["labelPath"][:pos]

		# A truthy "type" column marks an executable, otherwise a library.
		self["type"] = "bin" if self["type"] else "lib"

		self["unused"] = self["provided"] - self["used"]

		# Filled in by the manager once every module has been created.
		self["deps"] = []
		self["dependedBy"] = []
		self["human_size"] = ElfModule.__human_size(self["size"])
		self._symbolClass = ElfSymbol

		self._mgr = mgr

	def getAllDependedModules(self):
		"""Return callee modules of direct deps, then of indirect deps (queried)."""
		res = [dep["callee"] for dep in self["deps"]]
		res.extend(dep["callee"] for dep in self.getDepsIndirect())
		return res

	@staticmethod
	def __human_size(bytes, units=("B", "KB", "MB", "GB", "TB", "PB", "EB")):
		"""Format a byte count in 1024-step units, truncating each step.

		Values too large for the last unit are reported in that unit instead
		of running off the end of ``units``.
		"""
		if bytes < 1024 or len(units) == 1:
			return str(bytes) + units[0]
		return ElfModule.__human_size(bytes >> 10, units[1:])

	def isLibrary(self):
		"""Return True when this module is a library (not an executable)."""
		return self["type"] == "lib"

	def dependsOn(self, mod):
		"""Return True if ``mod`` is the callee of one of the direct deps."""
		return any(dep.getCallee() == mod for dep in self["deps"])

	def getDepByCallee(self, callee):
		"""Return the direct dependency whose callee is ``callee``.

		NOTE(review): returns ``NotImplemented`` (not None) when there is no
		match; kept for compatibility with existing callers.
		"""
		for dep in self["deps"]:
			if dep.getCallee() == callee:
				return dep
		return NotImplemented

	def _traverse_deps_tree_by_width(self, children, traversal, cookie, depth):
		"""Recursively visit callees that are not yet in ``children``.

		``traversal(dep, depth, cookie)`` is called once per newly reached
		callee; a truthy result stops descent below that callee. Each child's
		subtree is visited before moving on to the next sibling. Returns True
		if any subtree was cut off.
		"""
		hasSkippedItems = False
		for dep in self["deps"]:
			child = dep["callee"]

			# Already traversed
			if child in children:
				continue

			stop = traversal(dep, depth, cookie)

			# Add to children
			children.append(child)
			if stop:
				hasSkippedItems = True
				continue

			stop = child._traverse_deps_tree_by_width(children, traversal, cookie, depth + 1)
			if stop:
				hasSkippedItems = True

		return hasSkippedItems

	def traverse_deps_tree_by_width(self, traversal, cookie):
		"""Walk all reachable callees starting at depth 1; see the helper."""
		children = []
		return self._traverse_deps_tree_by_width(children, traversal, cookie, 1)

	def _traverse_deps_tree_by_depth(self, children, deps, traversal, cookie, depth):
		"""Like the width walk, but also collects every dependency edge seen
		into ``deps`` and accepts ``traversal=None`` (visit everything)."""
		hasSkippedItems = False
		for dep in self["deps"]:
			callee = dep["callee"]

			if dep not in deps:
				deps.append(dep)

			# Already traversed
			if callee in children:
				continue

			stop = False

			if traversal:
				stop = traversal(dep, depth, cookie)
			# Add to children
			children.append(callee)

			if stop:
				hasSkippedItems = True
				continue

			stop = callee._traverse_deps_tree_by_depth(children, deps, traversal, cookie, depth + 1)
			if stop:
				hasSkippedItems = True

		return hasSkippedItems

	def traverse_deps_tree_by_depth(self, traversal = None, cookie = None):
		"""Return ``(reachable callee modules, dependency edges visited)``."""
		children = []
		deps = []
		self._traverse_deps_tree_by_depth(children, deps, traversal, cookie, 1)
		return (children, deps)

	def _traverse_dependedBy_tree_by_depth(self, parents, deps, traversal, cookie, depth):
		"""Reverse walk over ``dependedBy``: visits callers instead of callees."""
		hasSkippedItems = False
		for dep in self["dependedBy"]:
			caller = dep["caller"]

			if dep not in deps:
				deps.append(dep)

			# Already traversed
			if caller in parents:
				continue

			stop = False

			if traversal:
				stop = traversal(dep, depth, cookie)
			# Add to parents
			parents.append(caller)

			if stop:
				hasSkippedItems = True
				continue

			stop = caller._traverse_dependedBy_tree_by_depth(parents, deps, traversal, cookie, depth + 1)
			if stop:
				hasSkippedItems = True

		return hasSkippedItems

	def traverse_dependedBy_tree_by_depth(self, traversal = None, cookie = None):
		"""Return ``(all transitive caller modules, dependency edges visited)``."""
		parents = []
		deps = []
		self._traverse_dependedBy_tree_by_depth(parents, deps, traversal, cookie, 1)
		return (parents, deps)

	def _render_dot_graph(self, filebasename, write_edges):
		"""Write ``tree/<filebasename>.txt`` as a Graphviz digraph and render it.

		``write_edges(f)`` receives the open text file and writes the edges.
		The PNG is produced by running ``dot -Tpng`` with an argument list (no
		shell), so module names containing shell metacharacters are safe.
		The ``tree`` directory must already exist. Returns the PNG path.
		"""
		import subprocess

		outputfile = "tree/" + filebasename + ".txt"
		# Text mode: everything written here is str, which a binary-mode file
		# rejects on Python 3.
		with open(outputfile, "w") as f:
			f.write("digraph g {\n")
			write_edges(f)
			f.write("}\n")

		pngfile = "tree/" + filebasename + ".png"
		with open(pngfile, "wb") as png:
			try:
				subprocess.call(["dot", "-Tpng", outputfile], stdout=png)
			except OSError as err:
				# A missing "dot" binary is reported but not fatal.
				print("failed to run dot for %s: %s" % (outputfile, err))
		return pngfile

	def output_deps_tree_by_width(self, maxDepth, minChild):
		"""Render the dependency graph to ``tree/`` and return the PNG path.

		maxDepth: stop descending past this depth (<= 0 means unlimited).
		minChild: skip callees with fewer direct deps than this; the filter
		          is disabled for modules with fewer than 30 total deps.
		"""
		if self["deps_total"] < 30:
			minChild = 0
		filebasename = self["name"] + "-by-width-" + str(self["deps_total"])
		return self._render_dot_graph(
			filebasename,
			lambda f: self.traverse_deps_tree_by_width(ElfModule.graph_writer, (f, maxDepth, minChild)))

	def output_deps_tree_by_depth(self, maxDepth):
		"""Render the dependency graph (depth walk) and return the PNG path.

		maxDepth: stop descending past this depth (<= 0 means unlimited).
		"""
		filebasename = self["name"] + "-by-depth-" + str(self["deps_total"])
		return self._render_dot_graph(
			filebasename,
			lambda f: self.traverse_deps_tree_by_depth(ElfModule.graph_writer, (f, maxDepth)))

	def getDetailName(self):
		"""Multi-line node label: name plus deps/total/by counts and size."""
		return "%s\n(deps %d,total %d,by %d,%s)" % (self["name"], len(self["deps"]), self["deps_total"], len(self["dependedBy"]), self["human_size"])

	def _getGraphVizColor(self):
		"""Fill colour for this module's ``modGroup`` (white if unlisted)."""
		return self._GRAPHVIZ_COLORS.get(self["modGroup"], "white")

	def _getGraphVizShape(self):
		"""3-D box for public API / entry modules, plain box otherwise."""
		if self["modGroup"] in ("publicapi", "pentry"):
			return "box3d"
		return "box"

	def getGraphVizInfo(self):
		"""Graphviz node statement for this module, keyed as ``m<id>``."""
		return '"m%d" [fillcolor=%s, style="rounded,filled", shape=%s, label="%s"];\n' % (self["id"], self._getGraphVizColor(), self._getGraphVizShape(), self.getDetailName())

	@staticmethod
	def graph_writer(dep, depth, cookie):
		"""Traversal callback that writes one edge per dependency.

		cookie: ``(file, maxDepth)`` or ``(file, maxDepth, minChild)``.
		Returns True to stop descending below ``dep``'s callee: either it has
		fewer than ``minChild`` deps (the edge is not written), or the next
		level would exceed ``maxDepth`` (the callee node is drawn filled if
		it has deps of its own).
		"""
		f = cookie[0]
		maxDepth = cookie[1]
		stop = False
		minChild = 0
		if len(cookie) > 2:
			minChild = cookie[2]

		if minChild > 0 and len(dep["callee"]["deps"]) < minChild:
			# not enough children, just ignore
			return True

		if maxDepth > 0 and depth + 1 > maxDepth:
			# Ignore next depth children
			stop = True

		callerName = "\"" + dep["caller"].getDetailName() + "\""
		calleeName = "\"" + dep["callee"].getDetailName() + "\""

		f.write(callerName + " -> " + calleeName)
		f.write("[label=\"%s\"]" % (dep["calls"]))
		f.write(";\n")

		if stop and len(dep["callee"]["deps"]) > 0:
			f.write(calleeName + " [ style = filled ];\n")

		return stop

	def _query_symbols(self, cursor, sqlcmd):
		"""Execute ``sqlcmd`` on ``cursor`` and wrap each row in ``_symbolClass``."""
		cursor.execute(sqlcmd)
		return [self._symbolClass(row) for row in cursor]

	def getUndefinedSymbols(self, cursor):
		"""All undefined symbols of this module."""
		sqlcmd = "select %s from undefines where parent_id=%d" % (self._symbolClass.QUERY_FIELDS_SIMPLE_STR, self["id"])
		return self._query_symbols(cursor, sqlcmd)

	def getUndefinedSymbolsMatched(self, cursor):
		"""Rows of ``call_details`` where this module is the caller."""
		sqlcmd = "select %s from call_details where caller_id=%d" % (self._symbolClass.QUERY_FIELDS_SIMPLE_STR, self["id"])
		return self._query_symbols(cursor, sqlcmd)

	def getUndefinedSymbolsUnmatched(self, cursor):
		"""Undefined symbols whose name has no ``call_details`` row for this caller."""
		sqlcmd = "select %s from undefines where parent_id=%d and name not in (select distinct(name) from call_details where caller_id=%d)" % (self._symbolClass.QUERY_FIELDS_SIMPLE_STR, self["id"], self["id"])
		return self._query_symbols(cursor, sqlcmd)

	def getUndefinedSymbolsDuplicated(self, cursor):
		"""Caller-side ``call_details`` rows whose name occurs more than once."""
		sqlcmd = "select %s from call_details where caller_id=%d and name in (select name from call_details where caller_id=%d GROUP by name HAVING count(*)  > 1) ORDER by name" % (self._symbolClass.QUERY_FIELDS_STR, self["id"], self["id"])
		return self._query_symbols(cursor, sqlcmd)

	def getProvidedSymbols(self, cursor):
		"""All symbols defined by this module."""
		sqlcmd = "select %s from symbols where parent_id=%d" % (self._symbolClass.QUERY_FIELDS_SIMPLE_STR, self["id"])
		return self._query_symbols(cursor, sqlcmd)

	def getProvidedSymbolsUsed(self, cursor):
		"""``call_details`` rows for symbols this module provides, by name."""
		sqlcmd = "select %s from call_details where parent_id=%d order by name" % (self._symbolClass.QUERY_FIELDS_STR, self["id"])
		return self._query_symbols(cursor, sqlcmd)

	def getProvidedSymbolsUnused(self, cursor):
		"""Provided symbols never referenced by a ``calls`` row to this module."""
		sqlcmd = "select %s from symbols where parent_id=%d and id not in (select symbol_id from calls where callee_id=%d)" % (self._symbolClass.QUERY_FIELDS_SIMPLE_STR, self["id"], self["id"])
		return self._query_symbols(cursor, sqlcmd)

	def _query_indirect(self, sqlcmd):
		"""Execute ``sqlcmd`` on the manager's cursor and wrap each row in an
		IndirectDependency (which reads column names from that same cursor)."""
		cursor = self._mgr.getCursor()
		cursor.execute(sqlcmd)
		return [IndirectDependency(row, self._mgr) for row in cursor]

	def getDepsIndirect(self):
		"""Indirect dependencies where this module is the caller."""
		return self._query_indirect("select * from indirect_details where caller_id=%d" % (self["id"]))

	def getDependedByIndirect(self):
		"""Indirect dependencies where this module is the callee."""
		return self._query_indirect("select * from indirect_details where callee_id=%d" % (self["id"]))

	def _query_deps_by_id(self, sqlcmd):
		"""Execute ``sqlcmd`` (selecting dependency ids) and map each id to the
		manager's dependency object."""
		cursor = self._mgr.getCursor()
		cursor.execute(sqlcmd)
		return [self._mgr.get_dep_by_id(row[0]) for row in cursor]

	def getChipsetSdkDependedBy(self):
		"""Direct dependencies on this module flagged ``chipsetsdk=1``."""
		return self._query_deps_by_id("select id from dependencies where callee_id=%d and chipsetsdk=1" % (self["id"]))

	def getChipsetSdkSymbols(self, cursor):
		"""Provided symbols used through chipset-SDK dependencies, by name."""
		sqlcmd = "select %s from call_details where parent_id=%d and chipsetsdk=1 order by name" % (self._symbolClass.QUERY_FIELDS_STR, self["id"])
		return self._query_symbols(cursor, sqlcmd)

	def getPlatformSdkDependedBy(self):
		"""Direct dependencies on this module flagged ``platformsdk=1``."""
		return self._query_deps_by_id("select id from dependencies where callee_id=%d and platformsdk=1" % (self["id"]))

	def getPlatformSdkSymbols(self, cursor):
		"""Provided symbols used through platform-SDK dependencies, by name."""
		sqlcmd = "select %s from call_details where parent_id=%d and platformsdk=1 order by name" % (self._symbolClass.QUERY_FIELDS_STR, self["id"])
		return self._query_symbols(cursor, sqlcmd)

	def getExternalSymbols(self, cursor):
		"""Provided symbols used through external dependencies, by name."""
		sqlcmd = "select %s from call_details where parent_id=%d and external=1 order by name" % (self._symbolClass.QUERY_FIELDS_STR, self["id"])
		return self._query_symbols(cursor, sqlcmd)

	def callByFunctionName(self, name, cursor):
		"""Call the method called ``name`` with ``cursor`` as its only argument."""
		method_to_call = getattr(self, name)
		return method_to_call(cursor)

	def __str__(self):
		# Arguments follow the labels: direct dep count, provided:needed
		# symbol counts, total deps (the same "deps_total" that getDetailName
		# shows) and reverse-dep count.
		# NOTE(review): "needed" is not one of the keys __init__ guarantees,
		# hence the .get() default.
		return "%s:%d deps(%d) symbols(%d:%d) depsTotal(%d) dependedBy(%d)" % (
			self["name"], self["id"], len(self["deps"]),
			self["provided"], self.get("needed", 0),
			self["deps_total"], len(self["dependedBy"]))

	def __repr__(self):
		return self.__str__()

class ElfSymbol(dict):
	"""One symbol row, keyed by the QUERY_FIELDS column names.

	Queries select either QUERY_FIELDS_STR (including ``caller``) or
	QUERY_FIELDS_SIMPLE_STR (without it). Every field missing from a shorter
	row is set to "", so all QUERY_FIELDS keys are always present.
	"""
	QUERY_FIELDS = ("name", "library", "weak", "version", "demangle", "caller")
	QUERY_FIELDS_STR = ", ".join(QUERY_FIELDS)

	QUERY_FIELDS_SIMPLE = ("name", "library", "weak", "version", "demangle")
	QUERY_FIELDS_SIMPLE_STR = ", ".join(QUERY_FIELDS_SIMPLE)

	def __init__(self, row):
		for idx, field in enumerate(ElfSymbol.QUERY_FIELDS):
			# Pad every missing trailing field (previously only the first one
			# was padded and the rest were left absent).
			self[field] = row[idx] if idx < len(row) else ""

class IndirectDependency(dict):
	"""A transitive caller -> callee link read from ``indirect_details``.

	All row columns are copied in (names taken from the manager's cursor).
	The caller/callee ids are then resolved to module objects. The call
	count and the external/SDK flags are forced to 0, and ``depType`` is
	"Indirect".
	"""
	def __init__(self, row, mgr):
		column_names = [column[0] for column in mgr.getCursor().description]
		self.update((column_name, row[pos]) for pos, column_name in enumerate(column_names))
		self["caller"] = mgr.find_by_id(self["caller_id"])
		self["callee"] = mgr.find_by_id(self["callee_id"])
		for flag in ("external", "platformsdk", "chipsetsdk"):
			self[flag] = 0
		self["depType"] = "Indirect"
		self["calls"] = 0

	def __eq__(self, other):
		# Identity is the row id only.
		if isinstance(other, IndirectDependency):
			return self["id"] == other["id"]
		return NotImplemented

	def __str__(self):
		src = self["caller"]
		dst = self["callee"]
		return "(%s[%d] -%d-> %s[%d])" % (src["name"], src["id"], self["calls"], dst["name"], dst["id"])

	def __repr__(self):
		return self.__str__()

class Dependency(dict):
	"""A direct caller -> callee link built from a ``dependencies`` row.

	Row positions used: 0 caller id, 1 callee id, 2 call count, 3 external,
	4 platformsdk, 5 chipsetsdk, 6 dependency id. ``depType`` is chosen by
	priority: Chipset SDK, then Platform SDK, then External, else Internal.
	"""
	def __init__(self, mgr, row):
		self["caller"] = mgr.find_by_id(row[0])
		self["callee"] = mgr.find_by_id(row[1])
		self["calls"] = row[2]
		self["external"] = row[3]
		self["platformsdk"] = row[4]
		self["chipsetsdk"] = row[5]

		if self["chipsetsdk"]:
			kind = "Chipset SDK"
		elif self["platformsdk"]:
			kind = "Platform SDK"
		elif self["external"]:
			kind = "External"
		else:
			kind = "Internal"
		self["depType"] = kind

		self["id"] = row[6]
		self._symbolClass = ElfSymbol

	def __eq__(self, other):
		# Identity is the dependency id only.
		if isinstance(other, Dependency):
			return self["id"] == other["id"]
		return NotImplemented

	def getCaller(self):
		return self["caller"]

	def getCallee(self):
		return self["callee"]

	def getCalls(self):
		return self["calls"]

	def __str__(self):
		src = self["caller"]
		dst = self["callee"]
		return "(%s[%d] -%d-> %s[%d])" % (src["name"], src["id"], self["calls"], dst["name"], dst["id"])

	def __repr__(self):
		return self.__str__()

	def getDependedSymbols(self, cursor):
		"""Return the ``call_details`` symbols that make up this dependency."""
		query = "select %s from call_details where dependence_id=%d" % (self._symbolClass.QUERY_FIELDS_STR, self["id"])
		cursor.execute(query)
		make_symbol = self._symbolClass
		return [make_symbol(record) for record in cursor.fetchall()]

class ElfModuleMgr(object):
	"""Load every module and direct dependency from the archinfo database.

	Modules are loaded in ``modules.id`` order and dependencies in
	``dependencies.id`` order. ``find_by_id`` / ``get_dep_by_id`` therefore
	index straight into those lists, which assumes ids start at 1 with no
	gaps.
	"""
	def __init__(self, cursor, elfClass=None, depClass=None):
		"""Load everything through ``cursor``.

		elfClass: called as ``elfClass(row, mgr)``; defaults to ElfModule.
		depClass: called as ``depClass(mgr, row)``; defaults to Dependency.
		"""
		self._cursor = cursor
		self._maxDepth = 0
		self._maxTotalDepends = 0
		self._deps = []
		self._is64bit = False
		self._elfClass = elfClass if elfClass else ElfModule
		self._depClass = depClass if depClass else Dependency

		self.__load_all_modules()

	def __del__(self):
		# Drop the module/dependency containers.
		self._modules = []
		self._path_dict = {}
		self._deps = []

	def getCursor(self):
		return self._cursor

	def __load_all_modules(self):
		"""Create one module object per row, index them by path, then load deps."""
		self._modules = []
		self._path_dict = {}

		sqlcmd = 'select * from modules order by id'
		self._cursor.execute(sqlcmd)
		for row in self._cursor:
			mod = self._elfClass(row, self)
			self._modules.append(mod)
			self._path_dict[mod["path"]] = mod
			# Any ".../lib64..." path marks a 64-bit image; see get_module_by_path.
			if not self._is64bit and mod["path"].find("/lib64") > 0:
				self._is64bit = True

		self.__load_dependencies()

	def __load_dependencies(self):
		"""Create dependency objects and link them into caller/callee lists."""
		# Ordered by id so that get_dep_by_id's positional lookup is valid;
		# without ORDER BY the row order is not guaranteed.
		sqlcmd = 'select * from dependencies order by id'
		self._cursor.execute(sqlcmd)
		for row in self._cursor:
			dep = self._depClass(self, row)
			caller = self.find_by_id(row[0])
			callee = self.find_by_id(row[1])
			caller["deps"].append(dep)
			callee["dependedBy"].append(dep)
			self._deps.append(dep)

	@staticmethod
	def _item_by_id(items, id):
		"""Return ``items[id - 1]`` when 1 <= id <= len(items), else None."""
		if id <= 0 or id > len(items):
			return None
		return items[id - 1]

	def get_dep_by_id(self, id):
		"""Dependency with the given 1-based id, or None if out of range."""
		return ElfModuleMgr._item_by_id(self._deps, id)

	def find_by_id(self, id):
		"""Module with the given 1-based id, or None if out of range."""
		return ElfModuleMgr._item_by_id(self._modules, id)

	def get_module_by_path(self, path):
		"""Look a module up by path; on 64-bit images also try the lib64 path."""
		if path in self._path_dict:
			return self._path_dict[path]
		if not self._is64bit:
			return None
		path = path.replace("system/lib/", "system/lib64/")
		path = path.replace("vendor/lib/", "vendor/lib64/")
		return self._path_dict.get(path)

	def get_all(self, xargs=None):
		return self._modules

	def get_all_deps(self, xargs=None):
		return self._deps

	def output_all_deps_tree(self, maxLevel, minChild):
		"""Render width- and depth-walk graphs for every module that has deps.

		maxLevel/minChild are forwarded to each module's
		``output_deps_tree_by_width``; the depth graph is capped at 14 levels.
		"""
		for mod in self._modules:
			# No direct deps implies no indirect deps either.
			if not mod["deps"]:
				continue
			print("output deps tree for [" + mod["name"] + "] now ...")
			mod.output_deps_tree_by_width(maxLevel, minChild)
			mod.output_deps_tree_by_depth(14)

if __name__ == "__main__":
	# Open the architecture database in the current directory and load it.
	conn = sqlite3.connect("archinfo.db")
	cursor = conn.cursor()
	modules = ElfModuleMgr(cursor)

	# Inspect the module with id 1 and dump the symbols it provides.
	# Other handy calls while debugging:
	#   modules.find_by_id(392)
	#   elf.output_deps_tree_by_depth(16)
	#   elf.output_deps_tree_by_width(5, 10)
	#   modules.get_module_by_path("system/lib64/libbeget_proxy.z.so")
	elf = modules.find_by_id(1)
	needed = elf.callByFunctionName("getProvidedSymbols", cursor)
	print(needed)

